if __name__ == '__main__':
    import sys
    sys.path.append('../../')

import datetime
import time
import json
from quant import Order
from quant.markets import Functions
from quant.markets.functions import get_asset_value
from quant.utils import logging
from quant.exchanges.basics import ChannelBasic, HandlerBasic, ApiBasic, SpiBasic
from quant.exchanges.huobi_util import pong, unzip
from quant.exchanges.huobi_util import build_req, format_symbol, resume_symbol, sign
from quant.exchanges.util import spot_value_factor, Socket, RecentChecker, update_balance_default, prepare_placing, prepare_canceling, int_prec_to_precision, simulate_deal_spot
from quant.const import *


class HuobiSpotChannel(ChannelBasic):
    """Public Huobi spot market-data channel (order book and trade streams)."""

    def init(self):
        # Market data is served on the plain `/ws` path of the WebSocket host.
        endpoint = ws_host + '/ws'
        self.socket = Socket(endpoint, self.on_open, self.on_message, self.on_close)

    def on_message(self, ws, message):
        """Answer heartbeats; forward every other decompressed frame to the markets feed."""
        arrived = time.time()
        text = unzip(message)
        # Heartbeat frames carry 'ping' at offset 2 (presumably '{"ping": ...}').
        is_heartbeat = text[2:6] == 'ping'
        if not is_heartbeat:
            self.markets.feed_raw(self.event, self.exchange, self.symbol, self.frequency, arrived, text)
            return None
        return pong(ws, text)

    def _send_market_sub(self, ws, topic_pattern):
        """Send a `sub` request for `topic_pattern` filled with this channel's slash-less symbol."""
        request = {
            'sub': topic_pattern.format(self.symbol.replace('/', '')),
            'id': '1234',
        }
        ws.send(json.dumps(request))

    def subscribe_book(self, ws):
        # 150-level depth: it also updates every 100 ms, and the 20-level feed is not deep enough.
        self._send_market_sub(ws, 'market.{}.mbp.150')

    def subscribe_trade(self, ws):
        self._send_market_sub(ws, 'market.{}.trade.detail')


class HuobiSpotHandler(HandlerBasic):
    """Parse raw Huobi spot market-data frames into book and trade updates.

    All server timestamps are converted from milliseconds to seconds before being
    pushed, so they are comparable with the ``recv_time`` values passed in.
    """

    def init(self):
        # Preferred dedup mechanism: RecentChecker presumably remembers the last 50 server ids.
        self.data_id_checker = RecentChecker(50)

    def process_book(self, routing_key, recv_time, raw):
        """Apply one depth update to ``self.book`` and publish it.

        Returns whatever ``process_close`` / ``info_welcome`` return for
        channel-close and subscription-ack frames; otherwise returns None.
        """
        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        data = json.loads(raw)
        if 'subbed' in data:
            return self.info_welcome()

        # Error / status frames carry no `tick`; report them instead of raising KeyError.
        if 'tick' not in data:
            logging.error('Huobi spot channel error: {}'.format(data))
            return

        tick = data['tick']
        server_id = tick['seqNum']
        server_time = data['ts'] / 1000  # ms -> s
        self.markets.info_engine.push_process_recv(server_id, server_time, recv_time)

        if not self.data_id_checker.is_new(server_id):
            return

        book = self.book
        for p, v in tick['bids']:
            book['buy', float(p)] = float(v)
        for p, v in tick['asks']:
            book['sell', float(p)] = float(v)

        self.check_book(book)
        self.push_book(routing_key, recv_time, server_time, server_id, book)

    def process_trade(self, routing_key, recv_time, raw):
        """Publish the trades contained in one trade-detail frame.

        The id of the last trade in the frame is used as the frame's server id.
        """
        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        data = json.loads(raw)
        if 'subbed' in data:
            return self.info_welcome()

        # Error / status frames carry no `tick`; report them instead of raising KeyError.
        if 'tick' not in data:
            logging.error('Huobi spot channel error: {}'.format(data))
            return

        # ms -> s, matching process_book (the original passed raw milliseconds here).
        server_time = data['ts'] / 1000
        trades = data['tick']['data']
        if not trades:
            return
        server_id = trades[-1]['id']
        self.markets.info_engine.push_process_recv(server_id, server_time, recv_time)

        if not self.data_id_checker.is_new(server_id):
            return

        trade_list = [[d['direction'], float(d['price']), float(d['amount'])] for d in trades]
        self.push_trade(routing_key, recv_time, server_time, server_id, trade_list)


class HuobiSpotApi(ApiBasic):
    """Signed REST client for one Huobi spot account.

    Queries and order operations are sent through ``self.requesting.request_async``.
    Each ``_on_*`` method is the callback that parses the response, updates
    ``self.account`` and publishes a user event.
    """
    account_id: str

    def init(self):
        # Huobi embeds the account id in the account/order endpoints used below.
        self.account_id = self.keys['account_id']

    def query_balance(self, symbol=None):
        """Request balances; ``_on_balance`` keeps only the assets of ``symbol`` (or all added assets)."""
        path = '/v1/account/accounts/{}/balance'.format(self.account_id)
        params = {'account-id': self.account_id}
        req = build_req(rest_host, path, 'GET', params, self.api_key, self.secret_key)
        self.requesting.request_async(symbol, req, self._on_balance)

    def query_all_balance(self):
        """Request balances; ``_on_all_balance`` replaces the account balance with every non-zero asset."""
        path = '/v1/account/accounts/{}/balance'.format(self.account_id)
        params = {'account-id': self.account_id}
        req = build_req(rest_host, path, 'GET', params, self.api_key, self.secret_key)
        self.requesting.request_async(None, req, self._on_all_balance)

    def query_all_position(self):
        # Not applicable to this spot API; intentionally a no-op.
        return

    def query_all_margin(self):
        # Not applicable to this spot API; intentionally a no-op.
        return

    def query_open_orders(self, symbol=None):
        """Request open orders for ``symbol``.

        With ``symbol=None`` this delegates to :meth:`query_all_open_orders`,
        because the per-symbol request needs a concrete exchange symbol.
        """
        if symbol is None:
            return self.query_all_open_orders()

        path = '/v1/order/openOrders'
        params = {
            'account-id': self.account_id,
            'symbol': format_symbol(symbol),
        }
        req = build_req(rest_host, path, 'GET', params, self.api_key, self.secret_key)
        self.requesting.request_async(symbol, req, self._on_open_orders)

    def query_all_open_orders(self):
        """Request open orders across all symbols of the account."""
        path = '/v1/order/openOrders'
        params = {
            'account-id': self.account_id,
        }
        req = build_req(rest_host, path, 'GET', params, self.api_key, self.secret_key)
        self.requesting.request_async(None, req, self._on_all_open_orders)

    def place_order(self, order):
        """Submit ``order`` and return its client id.

        The order is registered with the order processor before the request is
        sent; ``_on_place`` later records the exchange order id or the failure.
        """
        order = prepare_placing(order, self._create_cid())

        path = '/v1/order/orders/place'
        params = {
            'source': 'spot-api',
            'account-id': self.account_id,
            'amount': order.amount,
            'price': order.price,
            'symbol': format_symbol(order.symbol),
            'type': order_type_dic[order.side][order.order_type],
            'client-order-id': order.client_id,
        }

        if order.order_type == OrderType.Market:
            # Market orders carry no price.
            del params['price']
            if order.side == 'buy':
                # Huobi sizes buy-market orders by quote-currency value, so convert the
                # base amount. Assumes get_asset_value prices both assets in a common
                # unit -- TODO confirm.
                # NOTE(review): rounding to whole quote units can turn small orders into 0.
                asset0, asset1 = order.symbol.split('/')
                value0 = get_asset_value(asset0)
                value1 = get_asset_value(asset1)
                amount = value0 / value1 * order.amount
                params['amount'] = round(amount, 0)

        req = build_req(rest_host, path, 'POST', params, self.api_key, self.secret_key)
        self.account.order_processor.process_place_operation(order)
        self.requesting.request_async(order, req, self._on_place)
        return order.client_id

    def cancel_order(self, order):
        """Request cancellation of ``order``; logs an error and returns if it has no exchange order id."""
        order = prepare_canceling(order)
        order_id = order.order_id

        if order_id is None:
            return logging.error('cancel order without oid: {}'.format(order))

        path = '/v1/order/orders/{}/submitcancel'.format(order_id)
        param = {'order-id': order_id}
        if order.symbol:
            # Use the exchange symbol format, as every other request in this class does.
            param['symbol'] = format_symbol(order.symbol)

        req = build_req(rest_host, path, 'POST', param, self.api_key, self.secret_key)
        self.account.order_processor.process_cancel_operation(order)
        self.requesting.request_async(order, req, self._on_cancel)

    def _on_balance(self, symbol, response):
        """Update free/frozen balances for the assets of ``symbol`` (or every added asset)."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_balance', symbol, response)

        if symbol is None:
            interests = self._asset_added
        else:
            interests = {*symbol.split('/')}

        j = response.json()
        balance = self.account.balance

        for dic in j['data']['list']:
            asset = dic['currency']
            if asset in interests:
                if dic['type'] == 'trade':
                    update_balance_default(balance, asset, free=float(dic['balance']))
                elif dic['type'] == 'frozen':
                    update_balance_default(balance, asset, frozen=float(dic['balance']))

        self._put_data(UserEvent.Balance, balance)

    def _on_all_balance(self, _, response):
        """Replace the account balance with every asset that has a non-zero balance entry."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_balance', _, response)

        j = response.json()
        balance = {}

        for dic in j['data']['list']:
            amount = float(dic['balance'])
            # Compare numerically so both '0' and '0.000...' are treated as empty.
            if amount == 0:
                continue
            asset = dic['currency']
            if dic['type'] == 'trade':
                update_balance_default(balance, asset, free=amount)
            elif dic['type'] == 'frozen':
                update_balance_default(balance, asset, frozen=amount)

        self.account.balance.clear()
        self.account.balance.update(balance)

        self._put_data(UserEvent.Balance, balance)

    def _on_open_orders(self, symbol, response):
        """Convert the open orders of ``symbol`` into pending Orders (remaining amount only)."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_open_orders', symbol, response)

        open_orders = []
        for dic in response.json()['data']:
            order = Order(
                symbol=symbol,
                side=dic['type'].split('-')[0],
                price=float(dic['price']),
                amount=float(dic['amount']) - float(dic['filled-amount']),
                # When using this as a template: some exchanges do not assign a cid to manual orders.
                client_id=dic['client-order-id'],
                order_id=dic['id'],
                status=OrderStatus.Pending,
            )
            open_orders.append(order)

        self.account.order_processor.process_open_orders(open_orders)
        self._put_data(UserEvent.OpenOrders, self.account.orders)

    def _on_all_open_orders(self, param, response):
        """Convert all open orders of the account into pending Orders and publish them once."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_open_orders', None, response)

        open_orders = []
        for dic in response.json()['data']:
            order = Order(
                symbol=resume_symbol(dic['symbol']),
                side=dic['type'].split('-')[0],
                price=float(dic['price']),
                amount=float(dic['amount']) - float(dic['filled-amount']),
                # When using this as a template: some exchanges do not assign a cid to manual orders.
                client_id=dic['client-order-id'],
                order_id=dic['id'],
                status=OrderStatus.Pending,
            )
            open_orders.append(order)

        # Processed even when empty so the processor sees that nothing is open
        # (the original additionally repeated this call when the list was non-empty).
        self.account.order_processor.process_open_orders(open_orders)
        self._put_data(UserEvent.OpenOrders, self.account.orders)

    def _on_place(self, order, response):
        """Record the result of a place request on a copy of ``order`` and publish it."""
        order = order.copy()

        if int(response.status_code) // 100 != 2:
            self._fail_request('place_order', order, response)
            order.status = OrderStatus.PlaceFailed
        else:
            # Parse only after the HTTP check: error bodies are not guaranteed to be JSON.
            data = response.json()
            if data.get('status') != 'ok':
                self._fail_request('place_order', order, response)
                order.status = OrderStatus.PlaceFailed
            else:
                order.order_id = data['data']
                order.status = OrderStatus.Pending
                if order.order_type == OrderType.Market:
                    # Market orders are assumed filled immediately.
                    order.status = OrderStatus.FullyFilled

        order = self.account.order_processor.update_order(order)
        self._put_data(UserEvent.Order, order)

    def _on_cancel(self, order, response):
        """Report a failed cancel once, then mark a copy of ``order`` Canceled and publish it."""
        order = order.copy()

        succeeded = int(response.status_code) // 100 == 2
        if succeeded:
            # Parse only after the HTTP check: error bodies are not guaranteed to be JSON.
            res = response.json()
            succeeded = res.get('status') == 'ok'
        if not succeeded:
            self._fail_request('cancel_order', order, response)

        # NOTE(review): the order is marked Canceled even when the request failed
        # (behavior kept from the original); confirm this is intended.
        order.status = OrderStatus.Canceled
        order = self.account.order_processor.update_order(order)
        self._put_data(UserEvent.Order, order)

    def get_account_id(self):
        """Synchronously fetch the account list (the JSON body) to look up account ids."""
        path = '/v1/account/accounts'
        req = build_req(rest_host, path, 'GET', {}, self.api_key, self.secret_key)
        res = self.requesting.request(req)
        return res.json()

    def transfer_swap(self, amount, back_to_spot=False):
        # Not implemented for Huobi spot.
        pass


class HuobiSpotSpi(SpiBasic):
    """Authenticated Huobi spot WebSocket (v2) feed for order and balance pushes."""
    _listen_key = ''
    _extend_key_started = False

    def add_symbol(self, symbol):
        """Track ``symbol``; if the socket is already connected, (re)send the subscriptions."""
        self.connect_once()
        super().add_symbol(symbol)
        if self._socket.is_connected():
            self._on_login(self._socket)

    def _create_socket(self):
        # Account and order pushes are served on the v2 endpoint.
        host = ws_host + '/ws/v2'
        socket = Socket(host, self._on_open, self._on_message, self._on_close)
        return socket

    def _on_open(self, ws):
        """Send the signed `auth` request as soon as the socket opens."""
        sign_params = {
            "accessKey": self.api_key,
            "signatureMethod": "HmacSHA256",
            "signatureVersion": "2.1",
            "timestamp": datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S'),
        }
        sig = sign(ws_host, '/ws/v2', 'GET', sign_params, self.secret_key)

        sign_params['authType'] = 'api'
        sign_params['signature'] = sig
        data = {
            "action": "req",
            "ch": "auth",
            "params": sign_params
        }

        message = json.dumps(data)
        ws.send(message)

    def _on_login(self, ws):
        """Subscribe to balance updates and to order updates for all symbols."""
        data = {"action": "sub", "ch": 'accounts.update#2'}
        message = json.dumps(data)
        self._socket.send(message)

        data = {"action": "sub", "ch": 'orders#*'}
        message = json.dumps(data)
        self._socket.send(message)

    def _on_message(self, ws, message):
        """Dispatch one JSON frame: pushes, auth reply, heartbeat, or sub acknowledgement."""
        data = json.loads(message)
        action = data.get('action')

        if action == 'push':
            ch = data['ch']
            if ch.startswith('orders'):
                self._handle_order(data)
            elif ch.startswith('accounts'):
                self._handle_balance(data)

        elif data.get('ch') == 'auth':
            # A non-200 code means authentication was rejected; subscribing would only produce errors.
            if data.get('code', 200) != 200:
                logging.error('HuobiSpotSpi auth failed: {}'.format(message))
                return
            self._on_login(ws)
            return

        elif action == 'ping':
            self._handle_ping(data)

        elif action == 'sub':
            if data.get('code', 200) != 200:
                logging.warn('HuobiSpotSpi sub failed: {}'.format(message))

        else:
            logging.warn('HuobiSpotSpi unknown msg:{}'.format(message))

    def _handle_balance(self, data):
        """Update free/frozen for a tracked asset from an `accounts.update` push and publish."""
        data = data['data']
        # Pushes without a currency (e.g. an empty payload) carry no balance change.
        if not data or 'currency' not in data:
            return

        balance = self.account.balance
        asset = data['currency']
        if asset in self._asset_added:
            bal = float(data['balance'])
            free = float(data['available'])
            # Frozen amount is the part of the total balance that is not available.
            update_balance_default(balance, asset, free, bal - free)

        self._put_data(UserEvent.Balance, balance)

    def _handle_order(self, data):
        """Convert an `orders#*` push into an Order update.

        Any drop in remaining amount relative to the known order is emitted as a deal.
        """
        data = data['data']
        event_type = data['eventType']
        symbol = resume_symbol(data['symbol'])

        if symbol not in self._symbol_added:
            return

        # Creation events report the full size; later events report the remaining amount.
        if event_type == 'creation':
            amount = float(data['orderSize'])
        else:
            amount = float(data['remainAmt'])

        client_id = data['clientOrderId']
        old = self.account.orders.get(client_id)
        if old is not None:
            deal = old.amount - amount
        else:
            deal = 0

        order = Order(
            symbol=symbol,
            side=data['type'].split('-')[0],
            price=float(data['orderPrice']),
            amount=amount,
            order_id=str(data['orderId']),
            client_id=client_id,
            status=status_dic[data['orderStatus']],
        )
        order = self.account.order_processor.update_order(order)
        self._put_data(UserEvent.Order, order)

        if deal != 0:
            self._put_deal(order, deal)

    def _handle_ping(self, data):
        # Echo the heartbeat payload back with action 'pong'.
        data['action'] = 'pong'
        message = json.dumps(data)
        self._socket.send(message)


class HuobiSpotFunctions(Functions):
    """Public (unauthenticated) Huobi spot helpers: tickers, precisions, recent trades."""

    def get_value_factor(self, symbol):
        return spot_value_factor(symbol)

    def get_all_ticker(self):
        """Return ``{symbol: ticker}`` for online USDT-quoted pairs, or ``{}`` on HTTP error.

        ``vol`` and ``value`` are scaled by ``get_asset_value('usdt')``.
        """
        url = rest_host + '/v2/settings/common/symbols'
        response = self.session.get(url, timeout=3)
        if int(response.status_code) // 100 != 2:
            logging.warn('Huobi func-api error: {}'.format(response.text))
            return {}
        # Lower-cased display names of symbols whose state is 'online';
        # compared against resume_symbol() output below.
        valid = {d['dn'].lower() for d in response.json()['data'] if d['state'] == 'online'}

        url = rest_host + '/market/tickers'
        response = self.session.get(url, timeout=3)
        if int(response.status_code) // 100 != 2:
            logging.warn('Huobi func-api error: {}'.format(response.text))
            return {}

        result = {}
        _usdt = get_asset_value('usdt')
        for d in response.json()['data']:
            contract = d['symbol']
            if not contract.endswith('usdt'):
                continue

            symbol = resume_symbol(contract)
            if symbol not in valid:
                continue

            close = float(d['close'])
            result[symbol] = {
                'price': close,
                'vol': float(d['vol']) * _usdt,
                'value': close * _usdt,
                'bid': float(d['bid']),
                'ask': float(d['ask']),
            }

        return result

    def get_all_precision(self):
        """Return ``{symbol: (price_unit, amount_unit)}`` for USDT pairs, or ``{}`` on HTTP error."""
        url = rest_host + '/v2/settings/common/symbols'
        response = self.session.get(url, timeout=3)
        if int(response.status_code) // 100 != 2:
            logging.warn('Huobi func-api error: {}'.format(response.text))
            return {}
        # Parse only after the status check: error bodies are not guaranteed to be JSON.
        data = response.json()

        result = {}
        for d in data['data']:
            contract = d['sc']
            if not contract.endswith('usdt'):
                continue

            symbol = resume_symbol(contract)

            # `tpp` / `tap` are integer precisions, converted to price/amount units.
            price_unit = int_prec_to_precision(d['tpp'])
            amount_unit = int_prec_to_precision(d['tap'])
            result[symbol] = (price_unit, amount_unit)

        return result

    def get_all_contract_size(self):
        """Spot pairs have a contract size of 1 for every tradable ticker."""
        all_ticker = self.get_all_ticker()
        result = {s: 1 for s in all_ticker}
        return result

    def get_recent_trade(self, symbol, _=None):
        """Return recent trades as ``(id, direction, price, amount, ts_seconds)`` tuples.

        Returns ``[]`` on HTTP error or when the response has no ``data`` field.
        """
        url = rest_host + '/market/history/trade'
        param = {
            'symbol': format_symbol(symbol),
            'size': 2000
        }
        response = self.session.get(url, timeout=3, params=param)

        if int(response.status_code) // 100 != 2:
            logging.warn('Huobi func-api error: {}'.format(response.text))
            return []

        data = response.json()
        if 'data' not in data:
            return []

        # Each entry of data['data'] is itself a batch holding a `data` list of trades.
        result = [
            (
                d['id'],
                d['direction'],
                float(d['price']),
                float(d['amount']),
                d['ts'] / 1000,
            )
            for dic in data['data'] for d in dic['data']
        ]
        return result

    def simulate_deal(self, order, deal_amount, mid_price, account):
        return simulate_deal_spot(order, deal_amount, mid_price, account)


# Base URLs for the REST API and the WebSocket endpoints.
rest_host = 'https://api-aws.huobi.pro'
ws_host = 'wss://api-aws.huobi.pro'

# Huobi `type` strings for each side, keyed by our OrderType.
_BUY_ORDER_TYPES = {
    OrderType.PostOnly: 'buy-limit-maker',
    OrderType.Limit: 'buy-limit',
    OrderType.Market: 'buy-market',
}
_SELL_ORDER_TYPES = {
    OrderType.PostOnly: 'sell-limit-maker',
    OrderType.Limit: 'sell-limit',
    OrderType.Market: 'sell-market',
}
# Looked up as order_type_dic[side][order_type] when placing orders.
order_type_dic = {'buy': _BUY_ORDER_TYPES, 'sell': _SELL_ORDER_TYPES}

# Huobi `orderStatus` values (order pushes) mapped to OrderStatus;
# a partially filled order that was then canceled counts as Canceled.
status_dic = {
    'submitted': OrderStatus.Pending,
    'canceled': OrderStatus.Canceled,
    'partial-filled': OrderStatus.PartFilled,
    'filled': OrderStatus.FullyFilled,
    'partial-canceled': OrderStatus.Canceled,
}


if __name__ == '__main__':
    # Manual smoke tests against the live exchange; run this file directly.
    import os

    from quant.markets import Markets
    from quant.utils import set_test_mode
    from utils import print_trades
    set_test_mode()

    def _env_key():
        """Return the Huobi credential dict used by the manual tests.

        Credentials come from the HUOBI_API_KEY, HUOBI_SECRET_KEY and
        HUOBI_ACCOUNT_ID environment variables so no secrets live in source
        control. A missing variable raises KeyError naming it.
        """
        return {
            'api_key': os.environ['HUOBI_API_KEY'],
            'secret_key': os.environ['HUOBI_SECRET_KEY'],
            'account_id': os.environ['HUOBI_ACCOUNT_ID'],
        }

    def try_channel():
        mar = Markets()
        mar.add_market('Trade', 'Huobi', 'eos/usdt')

        print_trades(mar)
        input(':')

    def try_api():
        key = _env_key()

        api = HuobiSpotApi(key)
        api.account.info_engine.show_user_data()
        api.account.info_engine.show_problems()

        # api.place_order(Order('eth/usdt', 'buy', 1000, 0.02))
        # api.join()

        # api.query_all_open_orders()
        # api.join()

        # for o in api.account.orders.values():
        #     print('cancel:', o)
        #     api.cancel_order(o)
        #     break
        # api.join()

        # api.cancel_order(Order('eth/usdt', order_id='708618391554905'))
        # api.join()

        api.query_all_balance()
        api.join()

    def try_spi():
        key = _env_key()

        spi = HuobiSpotSpi(key)
        # spi.account.info_engine.show_user_data()
        spi.account.info_engine.show_problems()
        spi.connect_once()
        spi.add_symbol('eth/usdt')

        def on_ud(ud):
            print(ud.api_type, ud.data)
        spi.account.subscribe(UserEvent.Order, on_ud)

        input(':')
        api = HuobiSpotApi(key, spi.account)
        api.place_order(Order('eth/usdt', 'buy', 1300, 0.01))

        while True:
            input(':')

    def try_function():
        func = HuobiSpotFunctions()

        ticker = func.get_all_ticker()
        ticker = [(s, dic) for s, dic in ticker.items()]
        ticker.sort(key=lambda x: x[1]['vol'], reverse=True)
        for t in ticker:
            print(t)

        # prec = func.get_all_precision()
        # for i in prec.items():
        #     print(i)

        # from quant.markets.functions import get_precision
        # a = get_precision('Huobi', 'icp/usdt')
        # print(a)

        # trade = func.get_recent_trade('pi/usdt')
        # for i in trade:
        #     print(i)

    def account_id():
        from keys import get_key

        for i in range(1, 6):
            name = 'hb_rw{}'.format(i)
            key = get_key(name)
            api = HuobiSpotApi(key)
            # a = api.get_account_id()
            # print(name, a)

            api.account.info_engine.show_user_data()
            api.query_all_balance()
            api.join()

    account_id()






